#include "FramewiseLocalStrategy.h"

using namespace MMM;

/**
 * Creates a frame-by-frame converter that fits the output model to the labeled marker data.
 *
 * All inputs are forwarded unchanged to ConvertingStrategy; this class only adds the
 * cancellation flag, which starts out cleared.
 */
FramewiseLocalStrategy::FramewiseLocalStrategy(
        const std::map<float, std::map<std::string, Eigen::Vector3f> > &labeledMarkerData,
        MotionPtr outputMotion,
        ModelPtr outputModel,
        const std::vector<std::string> &joints,
        const std::map<std::string, std::string> &markerMapping,
        const std::map<std::string, float> &markerWeights)
    : ConvertingStrategy(labeledMarkerData, outputMotion, outputModel, joints, markerMapping, markerWeights)
    , cancelled(false)
{
    // Nothing beyond the initializer list is required.
}

void FramewiseLocalStrategy::cancel() {
    cancelled = true;
}

float FramewiseLocalStrategy::getCurrentTimestep() {
    return currentTimestep;
}

void FramewiseLocalStrategy::convert() {

    unsigned int frameNum = 0;
    std::map<float, std::vector<double>> optimizedFrames;
    for (const auto &labeledMarker : labeledMarkerData) {
        if (cancelled) throw MMM::Exception::ForcedCancelException();

        // Build initial configuration for optimization
        std::vector<double> configuration;
        if (frameNum > 0) {
            KinematicSensorMeasurementPtr kinematicSensorMeasurement = outputKinematicSensor->getDerivedMeasurement(currentTimestep);
            ModelPoseSensorMeasurementPtr modelPoseSensorMeasurement = outputModelPoseSensor->getDerivedMeasurement(currentTimestep);

            Eigen::Vector3f rootPos = modelPoseSensorMeasurement->getRootPosition();
            configuration.push_back(rootPos[0]); configuration.push_back(rootPos[1]); configuration.push_back(rootPos[2]);

            Eigen::Vector3f rootRot = modelPoseSensorMeasurement->getRootRotation();
            configuration.push_back(rootRot[0]); configuration.push_back(rootRot[1]); configuration.push_back(rootRot[2]);

            for (int i = 0; i < kinematicSensorMeasurement->getJointAngles().rows(); ++i) {
                configuration.push_back(kinematicSensorMeasurement->getJointAngles()[i]);
            }
        } else {
            configuration = std::vector<double>(frameDimension, 0.0);

            // Set joints to minimum value if minimum value is higher than zero, or to maximum value if maximum value is lower than zero
            int i = 0;
            for (const auto &joint : joints) {
                JointInfo jointInfo = outputModel->getModelNode(joint)->joint;

                float value = 0.0f;
                if (jointInfo.limitLo > 0) value = jointInfo.limitLo;
                if (jointInfo.limitHi < 0) value = jointInfo.limitHi;

                configuration[6 + i] = value;
                i++;
            }
        }
        currentTimestep = labeledMarker.first;

        // Initialize optimization
        nlopt::opt optimizer(nloptAlgorithm, frameDimension);
        optimizer.set_min_objective(ConvertingStrategy::objectiveFunctionWrapperStatic, this);
        optimizer.set_ftol_abs(0.0001);

        setOptimizationBounds(optimizer);

        // Run optimization
        double objectiveValue;
        try {
            nlopt::result resultCode = optimizer.optimize(configuration, objectiveValue);
            MMM_INFO << "Optimization for frame " << frameNum << " finished with code " << resultCode << ". " << std::endl;
        }
        catch (nlopt::roundoff_limited&) {
            MMM_INFO << "Optimization for frame " << frameNum << " finished by throwing nlopt::roundoff_limited (the result should be usable)." << std::endl;
        }
        optimizedFrames[currentTimestep] = configuration;
        // Create ModelPoseSensorMeasurement
        Eigen::Vector3f rootPos;
        rootPos[0] = configuration[0];
        rootPos[1] = configuration[1];
        rootPos[2] = configuration[2];
        Eigen::Vector3f rootRot;
        rootRot[0] = configuration[3];
        rootRot[1] = configuration[4];
        rootRot[2] = configuration[5];

        if (rootRot[0] > M_PI)
            rootRot[0] -= 2 * M_PI;
        if (rootRot[0] < -M_PI)
            rootRot[0] += 2 * M_PI;
        if (rootRot[1] > M_PI)
            rootRot[1] -= 2 * M_PI;
        if (rootRot[1] < -M_PI)
            rootRot[1] += 2 * M_PI;
        if (rootRot[2] > M_PI)
            rootRot[2] -= 2 * M_PI;
        if (rootRot[2] < -M_PI)
            rootRot[2] += 2 * M_PI;
        ModelPoseSensorMeasurementPtr modelPoseSensorMeasurement(new ModelPoseSensorMeasurement(currentTimestep, rootPos, rootRot));
        outputModelPoseSensor->addSensorMeasurement(modelPoseSensorMeasurement);

        // Create KinematicSensorMeasurement
        Eigen::VectorXf jointValues(joints.size());
        for (int i = 0; i < jointValues.rows(); ++i) {
            jointValues[i] = configuration[6 + i];
        }
        KinematicSensorMeasurementPtr kinematicSensorMeasurement(new KinematicSensorMeasurement(currentTimestep, jointValues));
        outputKinematicSensor->addSensorMeasurement(kinematicSensorMeasurement);

        // Output error (without CONVERTER_OUTPUT_MARKER_DEVIATION- console output actually is not that expensive!)
        setOutputModelConfiguration(configuration);

        double avgDistance, maxDistance;
        calculateMarkerDistancesAverageMaximum(labeledMarkerData[currentTimestep], avgDistance, maxDistance);

        std::cout << "Frame #" << frameNum << " finished: max error = " << maxDistance << ", avg error = " << avgDistance << std::endl;

        // For loop
        frameNum++;
    }
    double maxJerk, avgJerk;
    std::tie(maxJerk,avgJerk) = calculateMaxAndAverageJerk(optimizedFrames);
    MMM_INFO << "Max jerk: " << maxJerk << " avg jerk: " << avgJerk << std::endl;
}


/**
 * nlopt objective for a single frame.
 *
 * Poses the output model with 'configuration' and returns the (weighted, presumably)
 * sum of squared marker distances for the frame at currentTimestep.
 *
 * @param configuration [rootPos(3), rootRot(3), joint values...]; must have frameDimension entries.
 * @param grad          must be empty: only derivative-free algorithms are supported.
 * @return the marker error, or 0.0 if the call is invalid (see note below).
 */
double FramewiseLocalStrategy::objectiveFunction(const std::vector<double> &configuration, std::vector<double> &grad) {
    // nlopt passes a non-empty grad only when the algorithm requests a gradient.
    if (!grad.empty()) {
        MMM_ERROR << "NloptConverter: Gradient computation not supported!" << std::endl;
        // NOTE(review): 0.0 is the smallest possible sum of squares, so a minimizer will read this
        // error return as a perfect fit. Consider HUGE_VAL or nlopt::forced_stop instead.
        return 0.0;
    }

    if (configuration.size() != frameDimension) {
        MMM_ERROR << "NloptConverter: x has wrong number of dimensions (" << configuration.size()
                  << ", expected " << frameDimension << ")!" << std::endl;
        return 0.0;
    }

    setOutputModelConfiguration(configuration);

    return calculateMarkerDistancesSquaresSum(labeledMarkerData[currentTimestep]);
}





/**
 * Sets box constraints on 'optimizer' matching the configuration layout
 * [rootPos(3), rootRot(3), one value per entry of 'joints'].
 *
 * - Root position: limited to +/-10000 (10m).
 * - Root rotation: limited to [-pi - 0.2, pi + 0.2].
 * - Each joint: limited to its model limits [limitLo, limitHi].
 */
void FramewiseLocalStrategy::setOptimizationBounds(nlopt::opt& optimizer) const {
    // Some algorithms cannot handle unconstrained components (i.e. upper/lower limit of +/- infinity)
    const double positionLowerBound = -10000.0;  // 10m
    const double positionUpperBound = 10000.0;   // 10m

    // We must not limit the rotation strictly at +- pi because otherwise a local optimization algorithm can get stuck when the rotation angle overflows/underflows
    const double rotationOverflowBorder = 0.2;
    const double rotationLowerBound = -M_PI - rotationOverflowBorder;
    const double rotationUpperBound = M_PI + rotationOverflowBorder;

    std::vector<double> lowerBounds;
    std::vector<double> upperBounds;
    lowerBounds.reserve(6 + joints.size());
    upperBounds.reserve(6 + joints.size());

    // Root translation: 3 components
    lowerBounds.insert(lowerBounds.end(), 3, positionLowerBound);
    upperBounds.insert(upperBounds.end(), 3, positionUpperBound);

    // Root rotation: 3 components
    lowerBounds.insert(lowerBounds.end(), 3, rotationLowerBound);
    upperBounds.insert(upperBounds.end(), 3, rotationUpperBound);

    // One interval per optimized joint, in the same order as 'joints'
    for (const auto &joint : joints) {
        const JointInfo jointInfo = outputModel->getModelNode(joint)->joint;
        lowerBounds.push_back(jointInfo.limitLo);
        upperBounds.push_back(jointInfo.limitHi);
    }

    optimizer.set_lower_bounds(lowerBounds);
    optimizer.set_upper_bounds(upperBounds);
}
